{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp models.RNN_FCNPlus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RNN_FCNPlus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
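This is an unoff">
    "> This is an unofficial PyTorch implementation of the RNN-FCN family of models (an RNN, LSTM or GRU branch combined with a fully convolutional network branch), created by Ignacio Oguiza - oguiza@timeseriesAI.co"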
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from tsai.imports import *\n",
    "from tsai.models.layers import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class _RNN_FCN_Base_Backbone(Module):\n",
    "    def __init__(self, _cell, c_in, c_out, seq_len=None, hidden_size=100, rnn_layers=1, bias=True, cell_dropout=0, rnn_dropout=0.8, bidirectional=False, \n",
    "                 shuffle=True, conv_layers=[128, 256, 128], kss=[7, 5, 3], se=0):\n",
    "        \n",
    "        # RNN - first arg is usually c_in. Authors modified this to seq_len by not permuting x. This is what they call shuffled data.\n",
    "        self.rnn = _cell(seq_len if shuffle else c_in, hidden_size, num_layers=rnn_layers, bias=bias, batch_first=True, \n",
    "                              dropout=cell_dropout, bidirectional=bidirectional)\n",
    "        self.rnn_dropout = nn.Dropout(rnn_dropout) if rnn_dropout else noop\n",
    "        self.shuffle = Permute(0,2,1) if not shuffle else noop # You would normally permute x. Authors did the opposite.\n",
    "        \n",
    "        # FCN\n",
    "        assert len(conv_layers) == len(kss)\n",
    "        self.convblock1 = ConvBlock(c_in, conv_layers[0], kss[0])\n",
    "        self.se1 = SqueezeExciteBlock(conv_layers[0], se) if se != 0 else noop\n",
    "        self.convblock2 = ConvBlock(conv_layers[0], conv_layers[1], kss[1])\n",
    "        self.se2 = SqueezeExciteBlock(conv_layers[1], se) if se != 0 else noop\n",
    "        self.convblock3 = ConvBlock(conv_layers[1], conv_layers[2], kss[2])\n",
    "        self.gap = GAP1d(1)\n",
    "        \n",
    "        # Common\n",
    "        self.concat = Concat()\n",
    "        \n",
    "    def forward(self, x):  \n",
    "        # RNN\n",
    "        rnn_input = self.shuffle(x) # permute --> (batch_size, seq_len, n_vars) when batch_first=True\n",
    "        output, _ = self.rnn(rnn_input)\n",
    "        last_out = output[:, -1] # output of last sequence step (many-to-one)\n",
    "        last_out = self.rnn_dropout(last_out)\n",
    "        \n",
    "        # FCN\n",
    "        x = self.convblock1(x)\n",
    "        x = self.se1(x)\n",
    "        x = self.convblock2(x)\n",
    "        x = self.se2(x)\n",
    "        x = self.convblock3(x)\n",
    "        x = self.gap(x)\n",
    "\n",
    "        # Concat\n",
    "        x = self.concat([last_out, x])\n",
    "        return x\n",
    "\n",
    "        \n",
    "class _RNN_FCN_BasePlus(nn.Sequential):\n",
    "    def __init__(self, c_in, c_out, seq_len=None, d=None, hidden_size=100, rnn_layers=1, bias=True, cell_dropout=0, rnn_dropout=0.8, bidirectional=False, \n",
    "                 shuffle=True, fc_dropout=0., use_bn=False, conv_layers=[128, 256, 128], kss=[7, 5, 3], se=0, custom_head=None):\n",
    "        \n",
    "        if shuffle: assert seq_len is not None, 'need seq_len if shuffle=True'\n",
    "        conv_layers = listify(conv_layers)\n",
    "        \n",
    "        backbone = _RNN_FCN_Base_Backbone(self._cell, c_in, c_out, seq_len=seq_len, hidden_size=hidden_size, rnn_layers=rnn_layers, bias=bias, \n",
    "                                          cell_dropout=cell_dropout, rnn_dropout=rnn_dropout, bidirectional=bidirectional, shuffle=shuffle, \n",
    "                                          conv_layers=conv_layers, kss=kss, se=se)\n",
    "        \n",
    "        self.head_nf = hidden_size * (1 + bidirectional) + conv_layers[-1] \n",
    "        if custom_head:\n",
    "            if isinstance(custom_head, nn.Module): head = custom_head\n",
    "            else: head = custom_head(self.head_nf, c_out, seq_len)\n",
    "        else:\n",
    "            if d is None: mult = 1\n",
    "            elif isinstance(d, int): \n",
    "                mult = d\n",
    "                shape = (mult, c_out)\n",
    "            else:\n",
    "                mult = 1\n",
    "                for e in d: mult *= e\n",
    "                shape = tuple([e for e in d] + [c_out])\n",
    "            head_layers = []\n",
    "            if use_bn:\n",
    "                head_layers += [nn.BatchNorm1d(self.head_nf)]\n",
    "            if fc_dropout: head_layers += [nn.Dropout(fc_dropout)]\n",
    "            head_layers += [nn.Linear(self.head_nf, c_out * mult)] \n",
    "            if mult != 1:\n",
    "                head_layers += [Reshape(*shape)]\n",
    "            head = nn.Sequential(*head_layers)\n",
    "            \n",
    "        \n",
    "        layers = OrderedDict([('backbone', backbone), ('head', head)])\n",
    "        super().__init__(layers)\n",
    "\n",
    "\n",
    "class RNN_FCNPlus(_RNN_FCN_BasePlus):\n",
    "    _cell = nn.RNN\n",
    "    \n",
    "class LSTM_FCNPlus(_RNN_FCN_BasePlus):\n",
    "    _cell = nn.LSTM\n",
    "    \n",
    "class GRU_FCNPlus(_RNN_FCN_BasePlus):\n",
    "    _cell = nn.GRU\n",
    "    \n",
    "class MRNN_FCNPlus(_RNN_FCN_BasePlus):\n",
    "    _cell = nn.RNN\n",
    "    def __init__(self, *args, se=16, **kwargs):\n",
    "        super().__init__(*args, se=se, **kwargs)\n",
    "    \n",
    "class MLSTM_FCNPlus(_RNN_FCN_BasePlus):\n",
    "    _cell = nn.LSTM\n",
    "    def __init__(self, *args, se=16, **kwargs):\n",
    "        super().__init__(*args, se=se, **kwargs)\n",
    "    \n",
    "class MGRU_FCNPlus(_RNN_FCN_BasePlus):\n",
    "    _cell = nn.GRU\n",
    "    def __init__(self, *args, se=16, **kwargs):\n",
    "        super().__init__(*args, se=se, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.models.utils import count_parameters\n",
    "from tsai.models.RNN_FCN import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 16\n",
    "n_vars = 3\n",
    "seq_len = 12\n",
    "c_out = 2\n",
    "xb = torch.rand(bs, n_vars, seq_len)\n",
    "test_eq(RNN_FCNPlus(n_vars, c_out, seq_len)(xb).shape, [bs, c_out])\n",
    "test_eq(LSTM_FCNPlus(n_vars, c_out, seq_len)(xb).shape, [bs, c_out])\n",
    "test_eq(MLSTM_FCNPlus(n_vars, c_out, seq_len)(xb).shape, [bs, c_out])\n",
    "test_eq(GRU_FCNPlus(n_vars, c_out, shuffle=False)(xb).shape, [bs, c_out])\n",
    "test_eq(GRU_FCNPlus(n_vars, c_out, seq_len, shuffle=False)(xb).shape, [bs, c_out])\n",
    "test_eq(count_parameters(LSTM_FCNPlus(n_vars, c_out, seq_len)), count_parameters(LSTM_FCN(n_vars, c_out, seq_len)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 16\n",
    "n_vars = 3\n",
    "seq_len = 12\n",
    "c_out = 2\n",
    "xb = torch.rand(bs, n_vars, seq_len)\n",
    "custom_head = nn.Linear(228, c_out)\n",
    "test_eq(RNN_FCNPlus(n_vars, c_out, seq_len, custom_head=custom_head)(xb).shape, [bs, c_out])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 16\n",
    "n_vars = 3\n",
    "seq_len = 12\n",
    "d = 10\n",
    "c_out = 2\n",
    "xb = torch.rand(bs, n_vars, seq_len)\n",
    "test_eq(RNN_FCNPlus(n_vars, c_out, seq_len, d=d)(xb).shape, [bs, d, c_out])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 16\n",
    "n_vars = 3\n",
    "seq_len = 12\n",
    "d = (5, 3)\n",
    "c_out = 2\n",
    "xb = torch.rand(bs, n_vars, seq_len)\n",
    "test_eq(RNN_FCNPlus(n_vars, c_out, seq_len, d=d)(xb).shape, [bs, *d, c_out])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LSTM_FCNPlus(\n",
       "  (backbone): _RNN_FCN_Base_Backbone(\n",
       "    (rnn): LSTM(2, 100, batch_first=True)\n",
       "    (rnn_dropout): Dropout(p=0.8, inplace=False)\n",
       "    (convblock1): ConvBlock(\n",
       "      (0): Conv1d(3, 128, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)\n",
       "      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU()\n",
       "    )\n",
       "    (se1): SqueezeExciteBlock(\n",
       "      (avg_pool): GAP1d(\n",
       "        (gap): AdaptiveAvgPool1d(output_size=1)\n",
       "        (flatten): Reshape(bs)\n",
       "      )\n",
       "      (fc): Sequential(\n",
       "        (0): Linear(in_features=128, out_features=16, bias=False)\n",
       "        (1): ReLU()\n",
       "        (2): Linear(in_features=16, out_features=128, bias=False)\n",
       "        (3): Sigmoid()\n",
       "      )\n",
       "    )\n",
       "    (convblock2): ConvBlock(\n",
       "      (0): Conv1d(128, 256, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU()\n",
       "    )\n",
       "    (se2): SqueezeExciteBlock(\n",
       "      (avg_pool): GAP1d(\n",
       "        (gap): AdaptiveAvgPool1d(output_size=1)\n",
       "        (flatten): Reshape(bs)\n",
       "      )\n",
       "      (fc): Sequential(\n",
       "        (0): Linear(in_features=256, out_features=32, bias=False)\n",
       "        (1): ReLU()\n",
       "        (2): Linear(in_features=32, out_features=256, bias=False)\n",
       "        (3): Sigmoid()\n",
       "      )\n",
       "    )\n",
       "    (convblock3): ConvBlock(\n",
       "      (0): Conv1d(256, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
       "      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU()\n",
       "    )\n",
       "    (gap): GAP1d(\n",
       "      (gap): AdaptiveAvgPool1d(output_size=1)\n",
       "      (flatten): Reshape(bs)\n",
       "    )\n",
       "    (concat): Concat(dim=1)\n",
       "  )\n",
       "  (head): Sequential(\n",
       "    (0): Linear(in_features=228, out_features=12, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "LSTM_FCNPlus(n_vars, seq_len, c_out, se=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/047_models.RNN_FCNPlus.ipynb saved at 2023-07-03 18:11:18\n",
      "Correct notebook to script conversion! 😃\n",
      "Monday 03/07/23 18:11:21 CEST\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcO
wn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
